/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/trace_processor/perfetto_sql/intrinsics/table_functions/interval_intersect.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "perfetto/base/compiler.h"
#include "perfetto/base/logging.h"
#include "perfetto/base/status.h"
#include "perfetto/ext/base/status_or.h"
#include "perfetto/protozero/proto_decoder.h"
#include "perfetto/protozero/proto_utils.h"
#include "perfetto/trace_processor/basic_types.h"
#include "perfetto/trace_processor/status.h"
#include "protos/perfetto/trace_processor/metrics_impl.pbzero.h"
#include "src/trace_processor/containers/string_pool.h"
#include "src/trace_processor/db/column.h"
#include "src/trace_processor/db/table.h"
#include "src/trace_processor/perfetto_sql/intrinsics/table_functions/tables_py.h"
#include "src/trace_processor/util/status_macros.h"

namespace perfetto::trace_processor {
namespace tables {
IntervalIntersectTable::~IntervalIntersectTable() = default;
}  // namespace tables

namespace {

// Decoder for the RepeatedBuilderResult message nested inside a
// ProtoBuilderResult. NOTE(review): not referenced by the code in this file
// as shown; DecodeArgument spells the type out instead.
using RepeatedDecoder = protos::pbzero::RepeatedBuilderResult::Decoder;

// Iterator over packed fixed64-encoded values, exposed as int64_t. Used for
// the id, timestamp and duration columns passed to interval_intersect.
using RepeatedIter = ::protozero::PackedRepeatedFieldIterator<
    ::protozero::proto_utils::ProtoWireType::kFixed64,
    int64_t>;

// Interprets |raw_arg| as a BLOB produced by the RepeatedField function and
// returns an iterator over its integer values.
//
// |debug_name| is only used to build error messages.
// |parse_error| is passed by address to the returned iterator. Callers keep
// checking it while iterating, so it must outlive the iterator. Presumably the
// iterator also references (rather than copies) the bytes of |raw_arg|.
//
// Returns an error if |raw_arg| is not a BLOB, if it was not generated by
// RepeatedField, or if decoding the values reports a parse error immediately.
base::StatusOr<RepeatedIter> DecodeArgument(const SqlValue& raw_arg,
                                            const char* debug_name,
                                            bool& parse_error) {
  const bool is_blob = raw_arg.type == SqlValue::kBytes;
  if (!is_blob) {
    return base::ErrStatus(
        "interval_intersect: '%s' should be a repeated field", debug_name);
  }

  const auto* blob = static_cast<const uint8_t*>(raw_arg.AsBytes());
  protos::pbzero::ProtoBuilderResult::Decoder builder_result(
      blob, raw_arg.bytes_count);
  if (!builder_result.is_repeated()) {
    return base::ErrStatus(
        "interval_intersect: '%s' is not generated by RepeatedField "
        "function",
        debug_name);
  }

  protos::pbzero::RepeatedBuilderResult::Decoder repeated(
      builder_result.repeated());
  auto values = repeated.int_values(&parse_error);
  if (parse_error) {
    return base::ErrStatus(
        "interval_intersect: error when parsing '%s' values.", debug_name);
  }
  return values;
}

// One input interval covering [ts, end()), tagged with the id supplied for it
// by the caller.
struct Interval {
  int64_t id;
  int64_t ts;
  int64_t dur;

  // Exclusive end of the interval. It is const so it can be used on const
  // Intervals and through const references.
  // NOTE(review): plain signed addition; no guard against overflow or
  // negative |dur| values.
  int64_t end() const { return ts + dur; }
};

// Walks three parallel RepeatedField columns (ids, timestamps, durations) in
// lockstep and yields one Interval per position.
struct IntervalsIterator {
  RepeatedIter ids;
  RepeatedIter tses;
  RepeatedIter durs;

  // Decodes the three columns of one side of the intersection.
  // |parse_error| is handed to each underlying iterator via DecodeArgument.
  // Returns the first decoding error encountered, if any.
  static base::StatusOr<IntervalsIterator> Create(
      const SqlValue& raw_ids,
      const SqlValue& raw_timestamps,
      const SqlValue& raw_durs,
      bool& parse_error) {
    ASSIGN_OR_RETURN(RepeatedIter ids,
                     DecodeArgument(raw_ids, "ids", parse_error));
    ASSIGN_OR_RETURN(RepeatedIter tses,
                     DecodeArgument(raw_timestamps, "timestamps", parse_error));
    ASSIGN_OR_RETURN(RepeatedIter durs,
                     DecodeArgument(raw_durs, "durations", parse_error));

    return IntervalsIterator{ids, tses, durs};
  }

  // Advances all three columns. Must only be called while the iterator is
  // valid (see operator bool).
  void operator++() {
    PERFETTO_DCHECK(ids && tses && durs);
    ++ids;
    ++tses;
    ++durs;
  }

  // The interval at the current position. Only meaningful while operator bool
  // returns true.
  Interval operator*() const { return Interval{*ids, *tses, *durs}; }

  // True while every column still has a value. Checking all three columns,
  // not just |ids|, keeps inputs with mismatched lengths from reading past the
  // end of the shorter columns. It also guarantees the DCHECK in operator++
  // holds whenever callers check validity before advancing.
  explicit operator bool() const { return ids && tses && durs; }
};

}  // namespace
IntervalIntersect::IntervalIntersect(StringPool* pool) : pool_(pool) {}
IntervalIntersect::~IntervalIntersect() = default;

// The output schema is the static schema of the generated
// IntervalIntersectTable.
Table::Schema IntervalIntersect::CreateSchema() {
  using ResultTable = tables::IntervalIntersectTable;
  return ResultTable::ComputeStaticSchema();
}

std::string IntervalIntersect::TableName() {
  return tables::IntervalIntersectTable::Name();
}

uint32_t IntervalIntersect::EstimateRowCount() {
  // TODO(mayzner): Give proper estimate.
  // Fixed placeholder: it does not depend on the inputs.
  constexpr uint32_t kPlaceholderRowCount = 1024;
  return kPlaceholderRowCount;
}

// Computes the pairwise intersection of two sets of intervals.
//
// |args| holds six arguments, each either NULL or the output of the
// RepeatedField function:
//   args[0..2]: ids, timestamps and durations of the left intervals;
//   args[3..5]: ids, timestamps and durations of the right intervals.
// If all three arguments of one side are NULL, that side is treated as empty
// and an empty table is returned. A side that is only partially NULL is an
// error.
//
// Each overlapping (left, right) pair is emitted once, as a row holding the
// overlap's |ts| and |dur| plus both ids. The search is a two-pointer sweep
// done in two passes:
//   1. the right interval starts inside the left one
//      (l.ts <= r.ts < l.end());
//   2. the left interval starts strictly inside the right one
//      (r.ts < l.ts < r.end()).
// Neither pass rewinds its iterators. The sweep is therefore only complete if
// both inputs are sorted by |ts| and the intervals within each input do not
// overlap each other.
// NOTE(review): presumably the SQL caller guarantees this; confirm.
//
// Returns an error if any argument fails to decode or parse.
base::StatusOr<std::unique_ptr<Table>> IntervalIntersect::ComputeTable(
    const std::vector<SqlValue>& args) {
  PERFETTO_DCHECK(args.size() == 6);

  auto make_empty_table = [this]() {
    return std::unique_ptr<Table>(
        std::make_unique<tables::IntervalIntersectTable>(pool_));
  };

  // A NULL argument means that the corresponding side has no intervals.
  auto is_null = [](const SqlValue& val) { return val.is_null(); };
  if (std::any_of(args.begin(), args.end(), is_null)) {
    // Either all left values or all right values must be NULL.
    if (std::all_of(args.begin(), args.begin() + 3, is_null) ||
        std::all_of(args.begin() + 3, args.begin() + 6, is_null)) {
      return make_empty_table();
    }
    return base::ErrStatus(
        "interval_intersect: not all of the arguments of one of the tables are "
        "null");
  }

  // Shared by all six underlying iterators. They presumably set it while being
  // advanced, so both passes check it, and it is checked once more at the end.
  bool parse_error = false;
  ASSIGN_OR_RETURN(
      IntervalsIterator l_it,
      IntervalsIterator::Create(args[0], args[1], args[2], parse_error));
  ASSIGN_OR_RETURN(
      IntervalsIterator r_it,
      IntervalsIterator::Create(args[3], args[4], args[5], parse_error));

  // If either side has no intervals there can be no intersections.
  if (!l_it || !r_it) {
    return make_empty_table();
  }

  // Pass 1 consumes |l_it| and |r_it|, so keep copies for pass 2.
  IntervalsIterator l_it_2 = l_it;
  IntervalsIterator r_it_2 = r_it;

  auto table = std::make_unique<tables::IntervalIntersectTable>(pool_);

  // Appends the overlap of |l_i| and |r_i| to |table|. Only called for pairs
  // where one interval starts inside the other.
  auto insert_intersection = [&table](Interval l_i, Interval r_i) {
    tables::IntervalIntersectTable::Row row;
    row.ts = std::max(r_i.ts, l_i.ts);
    row.dur = std::min(r_i.end(), l_i.end()) - row.ts;
    // NOTE(review): ids are narrowed from int64_t to uint32_t; confirm they
    // always fit.
    row.left_id = static_cast<uint32_t>(l_i.id);
    row.right_id = static_cast<uint32_t>(r_i.id);
    table->Insert(row);
  };

  // Pass 1: find every intersection where the right interval starts during
  // the left interval. Iterators are dereferenced only after validity checks.
  for (; l_it && r_it && !parse_error; ++l_it) {
    Interval l_i = *l_it;
    for (; r_it && !parse_error; ++r_it) {
      Interval r_i = *r_it;
      // |r_i| starts at or after the end of |l_i|, so move on to the next
      // |l_i|. Leave |r_it| in place so |r_i| is reconsidered for that |l_i|.
      if (r_i.ts >= l_i.end()) {
        break;
      }
      // |r_i| started before |l_i|. Any overlap between them is found by
      // pass 2.
      if (r_i.ts < l_i.ts) {
        continue;
      }
      insert_intersection(l_i, r_i);
    }
  }

  // Pass 2: find every intersection where the left interval starts during
  // the right interval.
  for (; r_it_2 && l_it_2 && !parse_error; ++r_it_2) {
    Interval r_i = *r_it_2;
    for (; l_it_2 && !parse_error; ++l_it_2) {
      Interval l_i = *l_it_2;
      if (l_i.ts >= r_i.end()) {
        break;
      }
      // Pairs with l_i.ts == r_i.ts were already emitted by pass 1. This
      // comparison is non-strict so those pairs are not counted twice.
      if (l_i.ts <= r_i.ts) {
        continue;
      }
      insert_intersection(l_i, r_i);
    }
  }

  if (parse_error) {
    return base::ErrStatus(
        "interval_intersect: Error in parsing of one of the arguments.");
  }

  return std::unique_ptr<Table>(std::move(table));
}

}  // namespace perfetto::trace_processor
